{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 文本分类(Text classification)\n",
    "文本分类任务是将一句话或一段话划分到某个具体的类别。比如垃圾邮件识别，文本情绪分类等。\n",
    "\n",
    "Example::   \n",
    "1,商务大床房，房间很大，床有2M宽，整体感觉经济实惠不错!\n",
    "\n",
    "\n",
    "其中开头的1是只这条评论的标签，表示是正面的情绪。我们将使用到的数据可以通过http://dbcloud.irocn.cn:8989/api/public/dl/dataset/chn_senti_corp.zip 下载并解压，当然也可以通过fastNLP自动下载该数据。\n",
    "\n",
    "数据中的内容如下图所示。接下来，我们将用fastNLP在这个数据上训练一个分类网络。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![jupyter](./cn_cls_example.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 步骤\n",
    "一共有以下的几个步骤  \n",
    "(1) 读取数据  \n",
    "(2) 预处理数据  \n",
    "(3) 选择预训练词向量  \n",
    "(4) 创建模型  \n",
    "(5) 训练模型  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (1) 读取数据\n",
    "fastNLP提供多种数据的自动下载与自动加载功能，对于这里我们要用到的数据，我们可以用\\ref{Loader}自动下载并加载该数据。更多有关Loader的使用可以参考\\ref{Loader}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP.io import ChnSentiCorpLoader\n",
    "\n",
    "loader = ChnSentiCorpLoader()  # 初始化一个中文情感分类的loader\n",
    "data_dir = loader.download()  # 这一行代码将自动下载数据到默认的缓存地址, 并将该地址返回\n",
    "data_bundle = loader.load(data_dir)  # 这一行代码将从{data_dir}处读取数据至DataBundle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "DataBundle的相关介绍，可以参考\\ref{}。我们可以打印该data_bundle的基本信息。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "In total 3 datasets:\n",
      "\tdev has 1200 instances.\n",
      "\ttrain has 9600 instances.\n",
      "\ttest has 1200 instances.\n",
      "In total 0 vocabs:\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(data_bundle)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以看出，该data_bundle中一个含有三个\\ref{DataSet}。通过下面的代码，我们可以查看DataSet的基本情况"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataSet({'raw_chars': 选择珠江花园的原因就是方便，有电动扶梯直接到达海边，周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般，但还算整洁。 泳池在大堂的屋顶，因此很小，不过女儿倒是喜欢。 包的早餐是西式的，还算丰富。 服务吗，一般 type=str,\n",
      "'target': 1 type=str},\n",
      "{'raw_chars': 15.4寸笔记本的键盘确实爽，基本跟台式机差不多了，蛮喜欢数字小键盘，输数字特方便，样子也很美观，做工也相当不错 type=str,\n",
      "'target': 1 type=str})\n"
     ]
    }
   ],
   "source": [
    "print(data_bundle.get_dataset('train')[:2])  # 查看Train集前两个sample"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (2) 预处理数据\n",
    "在NLP任务中，预处理一般包括: (a)将一整句话切分成汉字或者词; (b)将文本转换为index  \n",
    "\n",
    "fastNLP中也提供了多种数据集的处理类，这里我们直接使用fastNLP的ChnSentiCorpPipe。更多关于Pipe的说明可以参考\\ref{Pipe}。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastNLP.io import ChnSentiCorpPipe\n",
    "\n",
    "pipe = ChnSentiCorpPipe()\n",
    "data_bundle = pipe.process(data_bundle)  # 所有的Pipe都实现了process()方法，且输入输出都为DataBundle类型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "In total 3 datasets:\n",
      "\tdev has 1200 instances.\n",
      "\ttrain has 9600 instances.\n",
      "\ttest has 1200 instances.\n",
      "In total 2 vocabs:\n",
      "\tchars has 4409 entries.\n",
      "\ttarget has 2 entries.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(data_bundle)  # 打印data_bundle，查看其变化"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以看到除了之前已经包含的3个\\ref{DataSet}, 还新增了两个\\ref{Vocabulary}。我们可以打印DataSet中的内容"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DataSet({'raw_chars': 选择珠江花园的原因就是方便，有电动扶梯直接到达海边，周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般，但还算整洁。 泳池在大堂的屋顶，因此很小，不过女儿倒是喜欢。 包的早餐是西式的，还算丰富。 服务吗，一般 type=str,\n",
      "'target': 1 type=int,\n",
      "'chars': [338, 464, 1400, 784, 468, 739, 3, 289, 151, 21, 5, 88, 143, 2, 9, 81, 134, 2573, 766, 233, 196, 23, 536, 342, 297, 2, 405, 698, 132, 281, 74, 744, 1048, 74, 420, 387, 74, 412, 433, 74, 2021, 180, 8, 219, 1929, 213, 4, 34, 31, 96, 363, 8, 230, 2, 66, 18, 229, 331, 768, 4, 11, 1094, 479, 17, 35, 593, 3, 1126, 967, 2, 151, 245, 12, 44, 2, 6, 52, 260, 263, 635, 5, 152, 162, 4, 11, 336, 3, 154, 132, 5, 236, 443, 3, 2, 18, 229, 761, 700, 4, 11, 48, 59, 653, 2, 8, 230] type=list,\n",
      "'seq_len': 106 type=int},\n",
      "{'raw_chars': 15.4寸笔记本的键盘确实爽，基本跟台式机差不多了，蛮喜欢数字小键盘，输数字特方便，样子也很美观，做工也相当不错 type=str,\n",
      "'target': 1 type=int,\n",
      "'chars': [50, 133, 20, 135, 945, 520, 343, 24, 3, 301, 176, 350, 86, 785, 2, 456, 24, 461, 163, 443, 128, 109, 6, 47, 7, 2, 916, 152, 162, 524, 296, 44, 301, 176, 2, 1384, 524, 296, 259, 88, 143, 2, 92, 67, 26, 12, 277, 269, 2, 188, 223, 26, 228, 83, 6, 63] type=list,\n",
      "'seq_len': 56 type=int})\n"
     ]
    }
   ],
   "source": [
    "print(data_bundle.get_dataset('train')[:2])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "新增了一列为数字列表的chars，以及变为数字的target列。可以看出这两列的名称和刚好与data_bundle中两个Vocabulary的名称是一致的，我们可以打印一下Vocabulary看一下里面的内容。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Vocabulary(['选', '择', '珠', '江', '花']...)\n"
     ]
    }
   ],
   "source": [
    "char_vocab = data_bundle.get_vocab('chars')\n",
    "print(char_vocab)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Vocabulary是一个记录着词语与index之间映射关系的类，比如"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "'选'的index是338\n",
      "index:338对应的汉字是选\n"
     ]
    }
   ],
   "source": [
    "index = char_vocab.to_index('选')\n",
    "print(\"'选'的index是{}\".format(index))  # 这个值与上面打印出来的第一个instance的chars的第一个index是一致的\n",
    "print(\"index:{}对应的汉字是{}\".format(index, char_vocab.to_word(index))) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (3) 选择预训练词向量  \n",
    "由于Word2vec, Glove, Elmo, Bert等预训练模型可以增强模型的性能，所以在训练具体任务前，选择合适的预训练词向量非常重要。在fastNLP中我们提供了多种Embedding使得加载这些预训练模型的过程变得更加便捷。更多关于Embedding的说明可以参考\\ref{Embedding}。这里我们先给出一个使用word2vec的中文汉字预训练的示例，之后再给出一个使用Bert的文本分类。这里使用的预训练词向量为'cn-fastnlp-100d'，fastNLP将自动下载该embedding至本地缓存，fastNLP支持使用名字指定的Embedding以及相关说明可以参见\\ref{Embedding}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found 4321 out of 4409 words in the pre-training embedding.\n"
     ]
    }
   ],
   "source": [
    "from fastNLP.embeddings import StaticEmbedding\n",
    "\n",
    "word2vec_embed = StaticEmbedding(char_vocab, model_dir_or_name='cn-char-fastnlp-100d')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (4) 创建模型\n",
    "这里我们使用到的模型结构如下所示，补图"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "from fastNLP.modules import LSTM\n",
    "import torch\n",
    "\n",
    "# 定义模型\n",
    "class BiLSTMMaxPoolCls(nn.Module):\n",
    "    def __init__(self, embed, num_classes, hidden_size=400, num_layers=1, dropout=0.3):\n",
    "        super().__init__()\n",
    "        self.embed = embed\n",
    "        \n",
    "        self.lstm = LSTM(self.embed.embedding_dim, hidden_size=hidden_size//2, num_layers=num_layers, \n",
    "                         batch_first=True, bidirectional=True)\n",
    "        self.dropout_layer = nn.Dropout(dropout)\n",
    "        self.fc = nn.Linear(hidden_size, num_classes)\n",
    "        \n",
    "    def forward(self, chars, seq_len):  # 这里的名称必须和DataSet中相应的field对应，比如之前我们DataSet中有chars，这里就必须为chars\n",
    "        # chars:[batch_size, max_len]\n",
    "        # seq_len: [batch_size, ]\n",
    "        chars = self.embed(chars)\n",
    "        outputs, _ = self.lstm(chars, seq_len)\n",
    "        outputs = self.dropout_layer(outputs)\n",
    "        outputs, _ = torch.max(outputs, dim=1)\n",
    "        outputs = self.fc(outputs)\n",
    "        \n",
    "        return {'pred':outputs}  # [batch_size,], 返回值必须是dict类型，且预测值的key建议设为pred\n",
    "\n",
    "# 初始化模型\n",
    "model = BiLSTMMaxPoolCls(word2vec_embed, len(data_bundle.get_vocab('target')))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (5) 训练模型\n",
    "fastNLP提供了Trainer对象来组织训练过程，包括完成loss计算(所以在初始化Trainer的时候需要指定loss类型)，梯度更新(所以在初始化Trainer的时候需要提供优化器optimizer)以及在验证集上的性能验证(所以在初始化时需要提供一个Metric)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input fields after batch(if batch size is 2):\n",
      "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
      "\tchars: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 106]) \n",
      "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
      "target fields after batch(if batch size is 2):\n",
      "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
      "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
      "\n",
      "Evaluate data in 0.01 seconds!\n",
      "training epochs started 2019-09-03-23-57-10\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=3000), HTML(value='')), layout=Layout(display…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.43 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 1/10. Step:300/3000: \n",
      "\r",
      "AccuracyMetric: acc=0.81\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.44 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 2/10. Step:600/3000: \n",
      "\r",
      "AccuracyMetric: acc=0.8675\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.44 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 3/10. Step:900/3000: \n",
      "\r",
      "AccuracyMetric: acc=0.878333\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.43 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 4/10. Step:1200/3000: \n",
      "\r",
      "AccuracyMetric: acc=0.873333\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.44 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 5/10. Step:1500/3000: \n",
      "\r",
      "AccuracyMetric: acc=0.878333\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.42 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 6/10. Step:1800/3000: \n",
      "\r",
      "AccuracyMetric: acc=0.895833\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.44 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 7/10. Step:2100/3000: \n",
      "\r",
      "AccuracyMetric: acc=0.8975\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.43 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 8/10. Step:2400/3000: \n",
      "\r",
      "AccuracyMetric: acc=0.894167\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.48 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 9/10. Step:2700/3000: \n",
      "\r",
      "AccuracyMetric: acc=0.8875\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=38), HTML(value='')), layout=Layout(display='…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.43 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 10/10. Step:3000/3000: \n",
      "\r",
      "AccuracyMetric: acc=0.895833\n",
      "\n",
      "\r\n",
      "In Epoch:7/Step:2100, got best dev performance:\n",
      "AccuracyMetric: acc=0.8975\n",
      "Reloaded the best model.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=19), HTML(value='')), layout=Layout(display='…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 0.34 seconds!\n",
      "[tester] \n",
      "AccuracyMetric: acc=0.8975\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'AccuracyMetric': {'acc': 0.8975}}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from fastNLP import Trainer\n",
    "from fastNLP import CrossEntropyLoss\n",
    "from torch.optim import Adam\n",
    "from fastNLP import AccuracyMetric\n",
    "\n",
    "loss = CrossEntropyLoss()\n",
    "optimizer = Adam(model.parameters(), lr=0.001)\n",
    "metric = AccuracyMetric()\n",
    "device = 0 if torch.cuda.is_available() else 'cpu'  # 如果有gpu的话在gpu上运行，训练速度会更快\n",
    "\n",
    "trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss, \n",
    "                  optimizer=optimizer, batch_size=32, dev_data=data_bundle.get_dataset('dev'),\n",
    "                  metrics=metric, device=device)\n",
    "trainer.train()  # 开始训练，训练完成之后默认会加载在dev上表现最好的模型\n",
    "\n",
    "# 在测试集上测试一下模型的性能\n",
    "from fastNLP import Tester\n",
    "print(\"Performance on test is:\")\n",
    "tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n",
    "tester.test()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用Bert进行文本分类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading vocabulary file /home/yh/.fastNLP/embedding/bert-chinese-wwm/vocab.txt\n",
      "Load pre-trained BERT parameters from file /home/yh/.fastNLP/embedding/bert-chinese-wwm/chinese_wwm_pytorch.bin.\n",
      "Start to generating word pieces for word.\n",
      "Found(Or segment into word pieces) 4286 words out of 4409.\n",
      "input fields after batch(if batch size is 2):\n",
      "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
      "\tchars: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2, 106]) \n",
      "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
      "target fields after batch(if batch size is 2):\n",
      "\ttarget: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
      "\tseq_len: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2]) \n",
      "\n",
      "Evaluate data in 0.05 seconds!\n",
      "training epochs started 2019-09-04-00-02-37\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=3600), HTML(value='')), layout=Layout(display…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 15.89 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 1/3. Step:1200/3600: \n",
      "\r",
      "AccuracyMetric: acc=0.9\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 15.92 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 2/3. Step:2400/3600: \n",
      "\r",
      "AccuracyMetric: acc=0.904167\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=150), HTML(value='')), layout=Layout(display=…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 15.91 seconds!\n",
      "\r",
      "Evaluation on dev at Epoch 3/3. Step:3600/3600: \n",
      "\r",
      "AccuracyMetric: acc=0.918333\n",
      "\n",
      "\r\n",
      "In Epoch:3/Step:3600, got best dev performance:\n",
      "AccuracyMetric: acc=0.918333\n",
      "Reloaded the best model.\n",
      "Performance on test is:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=19), HTML(value='')), layout=Layout(display='…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r",
      "Evaluate data in 29.24 seconds!\n",
      "[tester] \n",
      "AccuracyMetric: acc=0.919167\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'AccuracyMetric': {'acc': 0.919167}}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 只需要切换一下Embedding即可\n",
    "from fastNLP.embeddings import BertEmbedding\n",
    "\n",
    "# 这里为了演示一下效果，所以默认Bert不更新权重\n",
    "bert_embed = BertEmbedding(char_vocab, model_dir_or_name='cn', auto_truncate=True, requires_grad=False)\n",
    "model = BiLSTMMaxPoolCls(bert_embed, len(data_bundle.get_vocab('target')), )\n",
    "\n",
    "\n",
    "import torch\n",
    "from fastNLP import Trainer\n",
    "from fastNLP import CrossEntropyLoss\n",
    "from torch.optim import Adam\n",
    "from fastNLP import AccuracyMetric\n",
    "\n",
    "loss = CrossEntropyLoss()\n",
    "optimizer = Adam(model.parameters(), lr=2e-5)\n",
    "metric = AccuracyMetric()\n",
    "device = 0 if torch.cuda.is_available() else 'cpu'  # 如果有gpu的话在gpu上运行，训练速度会更快\n",
    "\n",
    "trainer = Trainer(train_data=data_bundle.get_dataset('train'), model=model, loss=loss, \n",
    "                  optimizer=optimizer, batch_size=16, dev_data=data_bundle.get_dataset('test'),\n",
    "                  metrics=metric, device=device, n_epochs=3)\n",
    "trainer.train()  # 开始训练，训练完成之后默认会加载在dev上表现最好的模型\n",
    "\n",
    "# 在测试集上测试一下模型的性能\n",
    "from fastNLP import Tester\n",
    "print(\"Performance on test is:\")\n",
    "tester = Tester(data=data_bundle.get_dataset('test'), model=model, metrics=metric, batch_size=64, device=device)\n",
    "tester.test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
